From 4e5ff7f53d2213303958f50bba06a4a9dff83e6a Mon Sep 17 00:00:00 2001 From: Vincent Hanquez Date: Sat, 25 Jan 2014 16:51:51 +0000 Subject: [PATCH] Change the way parameters are created. This is still WIP and this commit is truly horrific. Sadly, it's just too much effort to do clean commit with this, and it doesn't mix with experimentation either. --- core/Benchmarks/Benchmarks.hs | 25 +- core/Network/TLS.hs | 24 +- core/Network/TLS/Context.hs | 337 ++++++---------------- core/Network/TLS/Context/Internal.hs | 224 ++++++++++++++ core/Network/TLS/Core.hs | 12 +- core/Network/TLS/Handshake.hs | 11 +- core/Network/TLS/Handshake/Certificate.hs | 5 +- core/Network/TLS/Handshake/Client.hs | 69 +++-- core/Network/TLS/Handshake/Common.hs | 31 +- core/Network/TLS/Handshake/Key.hs | 2 +- core/Network/TLS/Handshake/Process.hs | 2 +- core/Network/TLS/Handshake/Server.hs | 32 +- core/Network/TLS/Handshake/Signature.hs | 2 +- core/Network/TLS/Hooks.hs | 7 +- core/Network/TLS/IO.hs | 3 +- core/Network/TLS/Parameters.hs | 316 +++++++++++--------- core/Network/TLS/Receiving.hs | 2 +- core/Network/TLS/Sending.hs | 5 +- core/Network/TLS/Types.hs | 2 +- core/Network/TLS/X509.hs | 16 + core/Tests/Connection.hs | 66 +++-- core/Tests/Tests.hs | 8 +- core/tls.cabal | 7 + 23 files changed, 697 insertions(+), 511 deletions(-) create mode 100644 core/Network/TLS/Context/Internal.hs diff --git a/core/Benchmarks/Benchmarks.hs b/core/Benchmarks/Benchmarks.hs index 9cea486..903ba1d 100644 --- a/core/Benchmarks/Benchmarks.hs +++ b/core/Benchmarks/Benchmarks.hs @@ -8,6 +8,8 @@ import Criterion.Main import Control.Concurrent.Chan import Network.TLS import Data.X509 +import Data.X509.Validation +import Data.Default.Class import qualified Data.ByteString as B import qualified Data.ByteString.Lazy as L @@ -15,15 +17,22 @@ import qualified Data.ByteString.Lazy as L recvDataNonNull ctx = recvData ctx >>= \l -> if B.null l then recvDataNonNull ctx else return l getParams connectVer cipher = (cParams, sParams) - where sParams = defaultParamsServer - { pAllowedVersions = [connectVer] - , pCiphers = [cipher] - , pCredentials = Credentials [ (CertificateChain [simpleX509 $ PubKeyRSA pubKey], PrivKeyRSA privKey) ] - } - cParams = defaultParamsClient - { pAllowedVersions = [connectVer] - , pCiphers = [cipher] + where sParams = def { serverSupported = supported + , serverShared = def { + sharedCredentials = Credentials [ (CertificateChain [simpleX509 $ PubKeyRSA pubKey], PrivKeyRSA privKey) ] + } + } + cParams = (defaultParamsClient "" B.empty) + { clientSupported = supported + , clientShared = def { sharedValidationCache = ValidationCache + { cacheAdd = \_ _ _ -> return () + , cacheQuery = \_ _ _ -> return ValidationCachePass + } + } } + supported = def { supportedCiphers = [cipher] + , supportedVersions = [connectVer] + } (pubKey, privKey) = getGlobalRSAPair runTLSPipe params tlsServer tlsClient d name = bench name $ do diff --git a/core/Network/TLS.hs b/core/Network/TLS.hs index 8a4a540..13a6778 100644 --- a/core/Network/TLS.hs +++ b/core/Network/TLS.hs @@ -8,19 +8,17 @@ module Network.TLS ( -- * Context configuration - Params(..) - , RoleParams(..) - , ClientParams(..) + ClientParams(..) , ServerParams(..) - , updateClientParams - , updateServerParams + , ClientHooks(..) + , ServerHooks(..) + , Supported(..) + , Shared(..) , Logging(..) , Measurement(..) , CertificateUsage(..) , CertificateRejectReason(..) , defaultParamsClient - , defaultParamsServer - , defaultLogging , MaxFragmentEnum(..) , HashAndSignatureAlgorithm , HashAlgorithm(..) @@ -36,7 +34,6 @@ module Network.TLS , SessionData(..) , SessionManager(..) , noSessionManager - , setSessionManager -- * Backend abstraction , Backend(..) @@ -94,18 +91,27 @@ module Network.TLS -- * Exceptions , TLSException(..) + + -- * Validation Cache + , ValidationCache + , exceptionValidationCache ) where import Network.TLS.Backend (Backend(..)) -import Network.TLS.Struct ( Version(..), TLSError(..), TLSException(..) +import Network.TLS.Struct ( TLSError(..), TLSException(..) , HashAndSignatureAlgorithm, HashAlgorithm(..), SignatureAlgorithm(..) , Header(..), ProtocolType(..), CertificateType(..) , AlertDescription(..)) import Network.TLS.Crypto (KxError(..)) import Network.TLS.Cipher +import Network.TLS.Hooks +import Network.TLS.Measurement import Network.TLS.Credentials import Network.TLS.Compression (CompressionC(..), Compression(..), nullCompression) import Network.TLS.Context +import Network.TLS.Parameters import Network.TLS.Core import Network.TLS.Session +import Network.TLS.X509 +import Network.TLS.Types import Data.X509 (PubKey(..), PrivKey(..)) diff --git a/core/Network/TLS/Context.hs b/core/Network/TLS/Context.hs index 225b7af..41971b7 100644 --- a/core/Network/TLS/Context.hs +++ b/core/Network/TLS/Context.hs @@ -8,44 +8,18 @@ module Network.TLS.Context ( -- * Context configuration - Params(..) - , RoleParams(..) - , ClientParams(..) - , ServerParams(..) - , updateClientParams - , updateServerParams - , Logging(..) - , SessionID - , SessionData(..) - , MaxFragmentEnum(..) - , Measurement(..) - , CertificateUsage(..) - , CertificateRejectReason(..) - , defaultLogging - , defaultParamsClient - , defaultParamsServer - , withSessionManager - , setSessionManager - , getClientParams - , getServerParams - , credentialsGet + TLSParams -- * Context object and accessor - , Context + , Context(..) , Hooks(..) - , ctxParams - , ctxConnection , ctxEOF , ctxHasSSLv2ClientHello , ctxDisableSSLv2ClientHello , ctxEstablished - , ctxCiphers , ctxLogging , ctxWithHooks - , ctxRxState - , ctxTxState - , ctxHandshake - , ctxNeedEmptyPacket + , modifyHooks , setEOF , setEstablished , contextFlush @@ -63,13 +37,6 @@ module Network.TLS.Context , Information(..) , contextGetInformation - -- * deprecated types - , TLSParams - , TLSLogging - , TLSCertificateUsage - , TLSCertificateRejectReason - , TLSCtx - -- * New contexts , contextNew -- * Deprecated new contexts methods @@ -91,177 +58,55 @@ module Network.TLS.Context ) where import Network.TLS.Backend -import Network.TLS.Extension +import Network.TLS.Context.Internal import Network.TLS.Struct import Network.TLS.Cipher (Cipher(..), CipherKeyExchangeType(..)) -import Network.TLS.Compression (Compression) import Network.TLS.Credentials import Network.TLS.State -import Network.TLS.Handshake.State import Network.TLS.Hooks import Network.TLS.Record.State import Network.TLS.Parameters import Network.TLS.Measurement import Network.TLS.Types (Role(..)) +import Network.TLS.Handshake (handshakeClient, handshakeClientWith, handshakeServer, handshakeServerWith) import Data.Maybe (isJust) -import qualified Data.ByteString as B import Crypto.Random import Control.Concurrent.MVar import Control.Monad.State -import Control.Exception (throwIO, Exception()) import Data.IORef -import Data.Tuple -- deprecated imports import Network.Socket (Socket) import System.IO (Handle) --- | Information related to a running context, e.g. current cipher -data Information = Information - { infoVersion :: Version - , infoCipher :: Cipher - , infoCompression :: Compression - } deriving (Show,Eq) --- | A TLS Context keep tls specific state, parameters and backend information. -data Context = Context - { ctxConnection :: Backend -- ^ return the backend object associated with this context - , ctxParams :: Params - , ctxCiphers :: [Cipher] -- ^ prepared list of allowed ciphers according to parameters - , ctxState :: MVar TLSState - , ctxMeasurement :: IORef Measurement - , ctxEOF_ :: IORef Bool -- ^ has the handle EOFed or not. - , ctxEstablished_ :: IORef Bool -- ^ has the handshake been done and been successful. - , ctxNeedEmptyPacket :: IORef Bool -- ^ empty packet workaround for CBC guessability. - , ctxSSLv2ClientHello :: IORef Bool -- ^ enable the reception of compatibility SSLv2 client hello. - -- the flag will be set to false regardless of its initial value - -- after the first packet received. - , ctxTxState :: MVar RecordState -- ^ current tx state - , ctxRxState :: MVar RecordState -- ^ current rx state - , ctxHandshake :: MVar (Maybe HandshakeState) -- ^ optional handshake state - , ctxHooks :: IORef Hooks -- ^ hooks for this context - , ctxLockWrite :: MVar () -- ^ lock to use for writing data (including updating the state) - , ctxLockRead :: MVar () -- ^ lock to use for reading data (including updating the state) - , ctxLockState :: MVar () -- ^ lock used during read/write when receiving and sending packet. - -- it is usually nested in a write or read lock. - } +class TLSParams a where + getTLSCommonParams :: a -> CommonParams + getTLSRole :: a -> Role + getCiphers :: a -> [Cipher] + doHandshake :: a -> Context -> IO () + doHandshakeWith :: a -> Context -> Handshake -> IO () --- deprecated types, setup as aliases for compatibility. -type TLSParams = Params -type TLSCtx = Context -type TLSLogging = Logging -type TLSCertificateUsage = CertificateUsage -type TLSCertificateRejectReason = CertificateRejectReason +instance TLSParams ClientParams where + getTLSCommonParams cparams = ( clientSupported cparams + , clientShared cparams + , clientCommonHooks cparams + ) + getTLSRole _ = ClientRole + getCiphers cparams = supportedCiphers $ clientSupported cparams + doHandshake = handshakeClient + doHandshakeWith = handshakeClientWith -updateMeasure :: Context -> (Measurement -> Measurement) -> IO () -updateMeasure ctx f = do - x <- readIORef (ctxMeasurement ctx) - writeIORef (ctxMeasurement ctx) $! f x - -withMeasure :: Context -> (Measurement -> IO a) -> IO a -withMeasure ctx f = readIORef (ctxMeasurement ctx) >>= f - -contextFlush :: Context -> IO () -contextFlush = backendFlush . ctxConnection - -contextClose :: Context -> IO () -contextClose = backendClose . ctxConnection - --- | Information about the current context -contextGetInformation :: Context -> IO (Maybe Information) -contextGetInformation ctx = do - ver <- usingState_ ctx $ gets stVersion - (cipher,comp) <- failOnEitherError $ runRxState ctx $ gets $ \st -> (stCipher st, stCompression st) - case (ver, cipher) of - (Just v, Just c) -> return $ Just $ Information v c comp - _ -> return Nothing - -contextSend :: Context -> Bytes -> IO () -contextSend c b = updateMeasure c (addBytesSent $ B.length b) >> (backendSend $ ctxConnection c) b - -contextRecv :: Context -> Int -> IO Bytes -contextRecv c sz = updateMeasure c (addBytesReceived sz) >> (backendRecv $ ctxConnection c) sz - -ctxEOF :: Context -> IO Bool -ctxEOF ctx = readIORef $ ctxEOF_ ctx - -ctxHasSSLv2ClientHello :: Context -> IO Bool -ctxHasSSLv2ClientHello ctx = readIORef $ ctxSSLv2ClientHello ctx - -ctxDisableSSLv2ClientHello :: Context -> IO () -ctxDisableSSLv2ClientHello ctx = writeIORef (ctxSSLv2ClientHello ctx) False - -setEOF :: Context -> IO () -setEOF ctx = writeIORef (ctxEOF_ ctx) True - -ctxEstablished :: Context -> IO Bool -ctxEstablished ctx = readIORef $ ctxEstablished_ ctx - -ctxWithHooks :: Context -> (Hooks -> IO a) -> IO a -ctxWithHooks ctx f = readIORef (ctxHooks ctx) >>= f - -setEstablished :: Context -> Bool -> IO () -setEstablished ctx v = writeIORef (ctxEstablished_ ctx) v - -ctxLogging :: Context -> Logging -ctxLogging = pLogging . ctxParams - --- | create a new context using the backend and parameters specified. -contextNew :: (MonadIO m, CPRG rng, HasBackend backend) - => backend -- ^ Backend abstraction with specific method to interact with the connection type. - -> Params -- ^ Parameters of the context. - -> rng -- ^ Random number generator associated with this context. - -> m Context -contextNew backend params rng = liftIO $ do - initializeBackend backend - let role = case roleParams params of - Client {} -> ClientRole - Server {} -> ServerRole - let st = newTLSState rng role - - stvar <- newMVar st - eof <- newIORef False - established <- newIORef False - stats <- newIORef newMeasurement - -- we enable the reception of SSLv2 ClientHello message only in the - -- server context, where we might be dealing with an old/compat client. - sslv2Compat <- newIORef (role == ServerRole) - needEmptyPacket <- newIORef False - hooks <- newIORef defaultHooks - tx <- newMVar newRecordState - rx <- newMVar newRecordState - hs <- newMVar Nothing +instance TLSParams ServerParams where + getTLSCommonParams sparams = ( serverSupported sparams + , serverShared sparams + , serverCommonHooks sparams + ) + getTLSRole _ = ServerRole -- on the server we filter our allowed ciphers here according -- to the credentials and DHE parameters loaded - let ciphers = case roleParams params of - Client {} -> pCiphers params - Server sParams -> filterServer sParams $ pCiphers params - lockWrite <- newMVar () - lockRead <- newMVar () - lockState <- newMVar () - - when (null ciphers) $ error "no ciphers available with those parameters" - - return $ Context - { ctxConnection = getBackend backend - , ctxParams = params - , ctxCiphers = ciphers - , ctxState = stvar - , ctxTxState = tx - , ctxRxState = rx - , ctxHandshake = hs - , ctxMeasurement = stats - , ctxEOF_ = eof - , ctxEstablished_ = established - , ctxSSLv2ClientHello = sslv2Compat - , ctxNeedEmptyPacket = needEmptyPacket - , ctxHooks = hooks - , ctxLockWrite = lockWrite - , ctxLockRead = lockRead - , ctxLockState = lockState - } - where filterServer sParams ciphers = filter authorizedCKE ciphers + getCiphers sparams = filter authorizedCKE (supportedCiphers $ serverSupported sparams) where authorizedCKE cipher = case cipherKeyExchange cipher of CipherKeyExchange_RSA -> canEncryptRSA @@ -277,26 +122,83 @@ contextNew backend params rng = liftIO $ do CipherKeyExchange_ECDH_RSA -> False CipherKeyExchange_ECDHE_ECDSA -> False - canDHE = isJust $ serverDHEParams sParams + canDHE = isJust $ serverDHEParams sparams canSignDSS = SignatureDSS `elem` signingAlgs canSignRSA = SignatureRSA `elem` signingAlgs canEncryptRSA = isJust $ credentialsFindForDecrypting creds signingAlgs = credentialsListSigningAlgorithms creds - creds = credentialsGet params + creds = sharedCredentials $ serverShared sparams + doHandshake = handshakeServer + doHandshakeWith = handshakeServerWith + +-- | create a new context using the backend and parameters specified. +contextNew :: (MonadIO m, CPRG rng, HasBackend backend, TLSParams params) + => backend -- ^ Backend abstraction with specific method to interact with the connection type. + -> params -- ^ Parameters of the context. + -> rng -- ^ Random number generator associated with this context. + -> m Context +contextNew backend params rng = liftIO $ do + initializeBackend backend + + let role = getTLSRole params + st = newTLSState rng role + (supported, shared, commonHooks) = getTLSCommonParams params + ciphers = getCiphers params + + when (null ciphers) $ error "no ciphers available with those parameters" + + stvar <- newMVar st + eof <- newIORef False + established <- newIORef False + stats <- newIORef newMeasurement + -- we enable the reception of SSLv2 ClientHello message only in the + -- server context, where we might be dealing with an old/compat client. + sslv2Compat <- newIORef (role == ServerRole) + needEmptyPacket <- newIORef False + hooks <- newIORef defaultHooks + tx <- newMVar newRecordState + rx <- newMVar newRecordState + hs <- newMVar Nothing + lockWrite <- newMVar () + lockRead <- newMVar () + lockState <- newMVar () + + return $ Context + { ctxConnection = getBackend backend + , ctxShared = shared + , ctxSupported = supported + , ctxCommonHooks = commonHooks + , ctxCiphers = ciphers + , ctxState = stvar + , ctxTxState = tx + , ctxRxState = rx + , ctxHandshake = hs + , ctxDoHandshake = doHandshake params + , ctxDoHandshakeWith = doHandshakeWith params + , ctxMeasurement = stats + , ctxEOF_ = eof + , ctxEstablished_ = established + , ctxSSLv2ClientHello = sslv2Compat + , ctxNeedEmptyPacket = needEmptyPacket + , ctxHooks = hooks + , ctxLockWrite = lockWrite + , ctxLockRead = lockRead + , ctxLockState = lockState + } -- | create a new context on an handle. -contextNewOnHandle :: (MonadIO m, CPRG rng) +contextNewOnHandle :: (MonadIO m, CPRG rng, TLSParams params) => Handle -- ^ Handle of the connection. - -> Params -- ^ Parameters of the context. + -> params -- ^ Parameters of the context. -> rng -- ^ Random number generator associated with this context. -> m Context contextNewOnHandle handle params st = contextNew handle params st {-# DEPRECATED contextNewOnHandle "use contextNew" #-} -- | create a new context on a socket. -contextNewOnSocket :: (MonadIO m, CPRG rng) +contextNewOnSocket :: (MonadIO m, CPRG rng, TLSParams params) => Socket -- ^ Socket of the connection. - -> Params -- ^ Parameters of the context. + -> params -- ^ Parameters of the context. -> rng -- ^ Random number generator associated with this context. -> m Context contextNewOnSocket sock params st = contextNew sock params st @@ -305,62 +207,3 @@ contextNewOnSocket sock params st = contextNew sock params st contextHookSetHandshakeRecv :: Context -> (Handshake -> IO Handshake) -> IO () contextHookSetHandshakeRecv context f = liftIO $ modifyIORef (ctxHooks context) (\hooks -> hooks { hookRecvHandshake = f }) - -throwCore :: (MonadIO m, Exception e) => e -> m a -throwCore = liftIO . throwIO - -failOnEitherError :: MonadIO m => m (Either TLSError a) -> m a -failOnEitherError f = do - ret <- f - case ret of - Left err -> throwCore err - Right r -> return r - -usingState :: Context -> TLSSt a -> IO (Either TLSError a) -usingState ctx f = - modifyMVar (ctxState ctx) $ \st -> - let (a, newst) = runTLSState f st - in newst `seq` return (newst, a) - -usingState_ :: Context -> TLSSt a -> IO a -usingState_ ctx f = failOnEitherError $ usingState ctx f - -usingHState :: Context -> HandshakeM a -> IO a -usingHState ctx f = liftIO $ modifyMVar (ctxHandshake ctx) $ \mst -> - case mst of - Nothing -> throwCore $ Error_Misc "missing handshake" - Just st -> return $ swap (Just `fmap` runHandshake st f) - -getHState :: Context -> IO (Maybe HandshakeState) -getHState ctx = liftIO $ readMVar (ctxHandshake ctx) - -runTxState :: Context -> RecordM a -> IO (Either TLSError a) -runTxState ctx f = do - ver <- usingState_ ctx (getVersionWithDefault $ maximum $ pAllowedVersions $ ctxParams ctx) - modifyMVar (ctxTxState ctx) $ \st -> - case runRecordM f ver st of - Left err -> return (st, Left err) - Right (a, newSt) -> return (newSt, Right a) - -runRxState :: Context -> RecordM a -> IO (Either TLSError a) -runRxState ctx f = do - ver <- usingState_ ctx getVersion - modifyMVar (ctxRxState ctx) $ \st -> - case runRecordM f ver st of - Left err -> return (st, Left err) - Right (a, newSt) -> return (newSt, Right a) - -getStateRNG :: Context -> Int -> IO Bytes -getStateRNG ctx n = usingState_ ctx $ genRandom n - -withReadLock :: Context -> IO a -> IO a -withReadLock ctx f = withMVar (ctxLockRead ctx) (const f) - -withWriteLock :: Context -> IO a -> IO a -withWriteLock ctx f = withMVar (ctxLockWrite ctx) (const f) - -withRWLock :: Context -> IO a -> IO a -withRWLock ctx f = withReadLock ctx $ withWriteLock ctx f - -withStateLock :: Context -> IO a -> IO a -withStateLock ctx f = withMVar (ctxLockState ctx) (const f) diff --git a/core/Network/TLS/Context/Internal.hs b/core/Network/TLS/Context/Internal.hs new file mode 100644 index 0000000..0488c34 --- /dev/null +++ b/core/Network/TLS/Context/Internal.hs @@ -0,0 +1,224 @@ +-- | +-- Module : Network.TLS.Context.Internal +-- License : BSD-style +-- Maintainer : Vincent Hanquez +-- Stability : experimental +-- Portability : unknown +-- +module Network.TLS.Context.Internal + ( + -- * Context configuration + ClientParams(..) + , ServerParams(..) + , defaultParamsClient + , SessionID + , SessionData(..) + , MaxFragmentEnum(..) + , Measurement(..) + + -- * Context object and accessor + , Context(..) + , Hooks(..) + , ctxEOF + , ctxHasSSLv2ClientHello + , ctxDisableSSLv2ClientHello + , ctxEstablished + , ctxLogging + , ctxWithHooks + , modifyHooks + , setEOF + , setEstablished + , contextFlush + , contextClose + , contextSend + , contextRecv + , updateMeasure + , withMeasure + , withReadLock + , withWriteLock + , withStateLock + , withRWLock + + -- * information + , Information(..) + , contextGetInformation + + -- * Using context states + , throwCore + , usingState + , usingState_ + , runTxState + , runRxState + , usingHState + , getHState + , getStateRNG + ) where + +import Network.TLS.Backend +import Network.TLS.Extension +import Network.TLS.Cipher +import Network.TLS.Struct +import Network.TLS.Compression (Compression) +import Network.TLS.State +import Network.TLS.Handshake.State +import Network.TLS.Hooks +import Network.TLS.Record.State +import Network.TLS.Parameters +import Network.TLS.Measurement +import qualified Data.ByteString as B + +import Control.Concurrent.MVar +import Control.Monad.State +import Control.Exception (throwIO, Exception()) +import Data.IORef +import Data.Tuple + + +-- | Information related to a running context, e.g. current cipher +data Information = Information + { infoVersion :: Version + , infoCipher :: Cipher + , infoCompression :: Compression + } deriving (Show,Eq) + +-- | A TLS Context keep tls specific state, parameters and backend information. +data Context = Context + { ctxConnection :: Backend -- ^ return the backend object associated with this context + , ctxSupported :: Supported + , ctxShared :: Shared + , ctxCommonHooks :: CommonHooks + , ctxCiphers :: [Cipher] -- ^ prepared list of allowed ciphers according to parameters + , ctxState :: MVar TLSState + , ctxMeasurement :: IORef Measurement + , ctxEOF_ :: IORef Bool -- ^ has the handle EOFed or not. + , ctxEstablished_ :: IORef Bool -- ^ has the handshake been done and been successful. + , ctxNeedEmptyPacket :: IORef Bool -- ^ empty packet workaround for CBC guessability. + , ctxSSLv2ClientHello :: IORef Bool -- ^ enable the reception of compatibility SSLv2 client hello. + -- the flag will be set to false regardless of its initial value + -- after the first packet received. + , ctxTxState :: MVar RecordState -- ^ current tx state + , ctxRxState :: MVar RecordState -- ^ current rx state + , ctxHandshake :: MVar (Maybe HandshakeState) -- ^ optional handshake state + , ctxDoHandshake :: Context -> IO () + , ctxDoHandshakeWith :: Context -> Handshake -> IO () + , ctxHooks :: IORef Hooks -- ^ hooks for this context + , ctxLockWrite :: MVar () -- ^ lock to use for writing data (including updating the state) + , ctxLockRead :: MVar () -- ^ lock to use for reading data (including updating the state) + , ctxLockState :: MVar () -- ^ lock used during read/write when receiving and sending packet. + -- it is usually nested in a write or read lock. + } + +updateMeasure :: Context -> (Measurement -> Measurement) -> IO () +updateMeasure ctx f = do + x <- readIORef (ctxMeasurement ctx) + writeIORef (ctxMeasurement ctx) $! f x + +withMeasure :: Context -> (Measurement -> IO a) -> IO a +withMeasure ctx f = readIORef (ctxMeasurement ctx) >>= f + +contextFlush :: Context -> IO () +contextFlush = backendFlush . ctxConnection + +contextClose :: Context -> IO () +contextClose = backendClose . ctxConnection + +-- | Information about the current context +contextGetInformation :: Context -> IO (Maybe Information) +contextGetInformation ctx = do + ver <- usingState_ ctx $ gets stVersion + (cipher,comp) <- failOnEitherError $ runRxState ctx $ gets $ \st -> (stCipher st, stCompression st) + case (ver, cipher) of + (Just v, Just c) -> return $ Just $ Information v c comp + _ -> return Nothing + +contextSend :: Context -> Bytes -> IO () +contextSend c b = updateMeasure c (addBytesSent $ B.length b) >> (backendSend $ ctxConnection c) b + +contextRecv :: Context -> Int -> IO Bytes +contextRecv c sz = updateMeasure c (addBytesReceived sz) >> (backendRecv $ ctxConnection c) sz + +ctxEOF :: Context -> IO Bool +ctxEOF ctx = readIORef $ ctxEOF_ ctx + +ctxHasSSLv2ClientHello :: Context -> IO Bool +ctxHasSSLv2ClientHello ctx = readIORef $ ctxSSLv2ClientHello ctx + +ctxDisableSSLv2ClientHello :: Context -> IO () +ctxDisableSSLv2ClientHello ctx = writeIORef (ctxSSLv2ClientHello ctx) False + +setEOF :: Context -> IO () +setEOF ctx = writeIORef (ctxEOF_ ctx) True + +ctxEstablished :: Context -> IO Bool +ctxEstablished ctx = readIORef $ ctxEstablished_ ctx + +ctxWithHooks :: Context -> (Hooks -> IO a) -> IO a +ctxWithHooks ctx f = readIORef (ctxHooks ctx) >>= f + +modifyHooks :: Context -> (Hooks -> Hooks) -> IO () +modifyHooks ctx f = modifyIORef (ctxHooks ctx) f + +setEstablished :: Context -> Bool -> IO () +setEstablished ctx v = writeIORef (ctxEstablished_ ctx) v + +ctxLogging :: Context -> Logging +ctxLogging = logging . ctxCommonHooks + +throwCore :: (MonadIO m, Exception e) => e -> m a +throwCore = liftIO . throwIO + +failOnEitherError :: MonadIO m => m (Either TLSError a) -> m a +failOnEitherError f = do + ret <- f + case ret of + Left err -> throwCore err + Right r -> return r + +usingState :: Context -> TLSSt a -> IO (Either TLSError a) +usingState ctx f = + modifyMVar (ctxState ctx) $ \st -> + let (a, newst) = runTLSState f st + in newst `seq` return (newst, a) + +usingState_ :: Context -> TLSSt a -> IO a +usingState_ ctx f = failOnEitherError $ usingState ctx f + +usingHState :: Context -> HandshakeM a -> IO a +usingHState ctx f = liftIO $ modifyMVar (ctxHandshake ctx) $ \mst -> + case mst of + Nothing -> throwCore $ Error_Misc "missing handshake" + Just st -> return $ swap (Just `fmap` runHandshake st f) + +getHState :: Context -> IO (Maybe HandshakeState) +getHState ctx = liftIO $ readMVar (ctxHandshake ctx) + +runTxState :: Context -> RecordM a -> IO (Either TLSError a) +runTxState ctx f = do + ver <- usingState_ ctx (getVersionWithDefault $ maximum $ supportedVersions $ ctxSupported ctx) + modifyMVar (ctxTxState ctx) $ \st -> + case runRecordM f ver st of + Left err -> return (st, Left err) + Right (a, newSt) -> return (newSt, Right a) + +runRxState :: Context -> RecordM a -> IO (Either TLSError a) +runRxState ctx f = do + ver <- usingState_ ctx getVersion + modifyMVar (ctxRxState ctx) $ \st -> + case runRecordM f ver st of + Left err -> return (st, Left err) + Right (a, newSt) -> return (newSt, Right a) + +getStateRNG :: Context -> Int -> IO Bytes +getStateRNG ctx n = usingState_ ctx $ genRandom n + +withReadLock :: Context -> IO a -> IO a +withReadLock ctx f = withMVar (ctxLockRead ctx) (const f) + +withWriteLock :: Context -> IO a -> IO a +withWriteLock ctx f = withMVar (ctxLockWrite ctx) (const f) + +withRWLock :: Context -> IO a -> IO a +withRWLock ctx f = withReadLock ctx $ withWriteLock ctx f + +withStateLock :: Context -> IO a -> IO a +withStateLock ctx f = withMVar (ctxLockState ctx) (const f) diff --git a/core/Network/TLS/Core.hs b/core/Network/TLS/Core.hs index 8c36bd6..e2909cb 100644 --- a/core/Network/TLS/Core.hs +++ b/core/Network/TLS/Core.hs @@ -29,6 +29,7 @@ module Network.TLS.Core import Network.TLS.Context import Network.TLS.Struct import Network.TLS.State (getSession) +import Network.TLS.Parameters import Network.TLS.IO import Network.TLS.Session import Network.TLS.Handshake @@ -79,17 +80,22 @@ recvData ctx = liftIO $ do terminate err AlertLevel_Fatal InternalError (show err) process (Handshake [ch@(ClientHello {})]) = - -- on server context receiving a client hello == renegotiation + withRWLock ctx ((ctxDoHandshakeWith ctx) ctx ch) >> recvData ctx + {- case roleParams $ ctxParams ctx of Server sparams -> withRWLock ctx (handshakeServerWith sparams ctx ch) >> recvData ctx Client {} -> let reason = "unexpected client hello in client context" in terminate (Error_Misc reason) AlertLevel_Fatal UnexpectedMessage reason - process (Handshake [HelloRequest]) = + -} + process (Handshake [hr@HelloRequest]) = + withRWLock ctx ((ctxDoHandshakeWith ctx) ctx hr) >> recvData ctx + {- -- on client context, receiving a hello request == renegotiation case roleParams $ ctxParams ctx of Server {} -> let reason = "unexpected hello request in server context" in terminate (Error_Misc reason) AlertLevel_Fatal UnexpectedMessage reason Client cparams -> withRWLock ctx (handshakeClient cparams ctx) >> recvData ctx + -} process (Alert [(AlertLevel_Warning, CloseNotify)]) = tryBye >> setEOF ctx >> return B.empty process (Alert [(AlertLevel_Fatal, desc)]) = do @@ -107,7 +113,7 @@ recvData ctx = liftIO $ do session <- usingState_ ctx getSession case session of Session Nothing -> return () - Session (Just sid) -> withSessionManager (ctxParams ctx) (\s -> sessionInvalidate s sid) + Session (Just sid) -> sessionInvalidate (sharedSessionManager $ ctxShared ctx) sid catchException (sendPacket ctx $ Alert [(level, desc)]) (\_ -> return ()) setEOF ctx E.throwIO (Terminated False reason err) diff --git a/core/Network/TLS/Handshake.hs b/core/Network/TLS/Handshake.hs index 2baa77a..f232a31 100644 --- a/core/Network/TLS/Handshake.hs +++ b/core/Network/TLS/Handshake.hs @@ -7,11 +7,13 @@ -- module Network.TLS.Handshake ( handshake + , handshakeClientWith , handshakeServerWith , handshakeClient + , handshakeServer ) where -import Network.TLS.Context +import Network.TLS.Context.Internal import Network.TLS.Struct import Network.TLS.IO import Network.TLS.Util (catchException) @@ -26,11 +28,8 @@ import Control.Exception (fromException) -- | Handshake for a new TLS connection -- This is to be called at the beginning of a connection, and during renegotiation handshake :: MonadIO m => Context -> m () -handshake ctx = do - let handshakeF = case roleParams $ ctxParams ctx of - Server sparams -> handshakeServer sparams - Client cparams -> handshakeClient cparams - liftIO $ handleException $ withRWLock ctx (handshakeF ctx) +handshake ctx = + liftIO $ handleException $ withRWLock ctx (ctxDoHandshake ctx $ ctx) where handleException f = catchException f $ \exception -> do let tlserror = maybe (Error_Misc $ show exception) id $ fromException exception setEstablished ctx False diff --git a/core/Network/TLS/Handshake/Certificate.hs b/core/Network/TLS/Handshake/Certificate.hs index b57cbaf..730130b 100644 --- a/core/Network/TLS/Handshake/Certificate.hs +++ b/core/Network/TLS/Handshake/Certificate.hs @@ -10,8 +10,9 @@ module Network.TLS.Handshake.Certificate , rejectOnException ) where -import Network.TLS.Context +import Network.TLS.Context.Internal import Network.TLS.Struct +import Network.TLS.X509 import Control.Monad.State import Control.Exception (SomeException) @@ -26,5 +27,5 @@ certificateRejected CertificateRejectUnknownCA = certificateRejected (CertificateRejectOther s) = throwCore $ Error_Protocol ("certificate rejected: " ++ s, True, CertificateUnknown) -rejectOnException :: SomeException -> IO TLSCertificateUsage +rejectOnException :: SomeException -> IO CertificateUsage rejectOnException e = return $ CertificateUsageReject $ CertificateRejectOther $ show e diff --git a/core/Network/TLS/Handshake/Client.hs b/core/Network/TLS/Handshake/Client.hs index 1ba28d2..59e32d9 100644 --- a/core/Network/TLS/Handshake/Client.hs +++ b/core/Network/TLS/Handshake/Client.hs @@ -8,10 +8,12 @@ -- module Network.TLS.Handshake.Client ( handshakeClient + , handshakeClientWith ) where import Network.TLS.Crypto -import Network.TLS.Context +import Network.TLS.Context.Internal +import Network.TLS.Parameters import Network.TLS.Struct import Network.TLS.Cipher import Network.TLS.Compression @@ -41,6 +43,10 @@ import Network.TLS.Handshake.Signature import Network.TLS.Handshake.Key import Network.TLS.Handshake.State +handshakeClientWith :: ClientParams -> Context -> Handshake -> IO () +handshakeClientWith cparams ctx HelloRequest = handshakeClient cparams ctx +handshakeClientWith _ _ _ = throwCore $ Error_Protocol ("unexpected handshake message received in handshakeClientWith", True, HandshakeFailure) + -- client part of handshake. send a bunch of handshake of client -- values intertwined with response from the server. handshakeClient :: ClientParams -> Context -> IO () @@ -50,31 +56,32 @@ handshakeClient cparams ctx = do recvServerHello sentExtensions sessionResuming <- usingState_ ctx isSessionResuming if sessionResuming - then sendChangeCipherAndFinish ctx ClientRole + then sendChangeCipherAndFinish sendMaybeNPN ctx ClientRole else do sendClientData cparams ctx - sendChangeCipherAndFinish ctx ClientRole + sendChangeCipherAndFinish sendMaybeNPN ctx ClientRole recvChangeCipherAndFinish ctx handshakeTerminate ctx - where params = ctxParams ctx - ciphers = pCiphers params - compressions = pCompressions params + where ciphers = ctxCiphers ctx + compressions = supportedCompressions $ ctxSupported ctx getExtensions = sequence [sniExtension,secureReneg,npnExtention] >>= return . catMaybes toExtensionRaw :: Extension e => e -> ExtensionRaw toExtensionRaw ext = (extensionID ext, extensionEncode ext) secureReneg = - if pUseSecureRenegotiation params + if supportedSecureRenegotiation $ ctxSupported ctx then usingState_ ctx (getVerifiedData ClientRole) >>= \vd -> return $ Just $ toExtensionRaw $ SecureRenegotiation vd Nothing else return Nothing - npnExtention = if isJust $ onNPNServerSuggest cparams + npnExtention = if isJust $ onNPNServerSuggest $ clientHooks cparams then return $ Just $ toExtensionRaw $ NextProtocolNegotiation [] else return Nothing - sniExtension = return ((\h -> toExtensionRaw $ ServerName [(ServerNameHostName h)]) <$> clientUseServerName cparams) + sniExtension = if clientUseServerNameIndication cparams + then return $ Just $ toExtensionRaw $ ServerName [ServerNameHostName $ fst $ clientServerIdentification cparams] + else return Nothing sendClientHello = do crand <- getStateRNG ctx 32 >>= return . ClientRandom let clientSession = Session . maybe Nothing (Just . fst) $ clientWantSessionResume cparams - highestVer = maximum $ pAllowedVersions params + highestVer = maximum $ supportedVersions $ ctxSupported ctx extensions <- getExtensions startHandshake ctx highestVer crand usingState_ ctx $ setVersionIfUnset highestVer @@ -84,6 +91,18 @@ handshakeClient cparams ctx = do ] return $ map fst extensions + sendMaybeNPN = do + suggest <- usingState_ ctx $ getServerNextProtocolSuggest + case (onNPNServerSuggest $ clientHooks cparams, suggest) of + -- client offered, server picked up. send NPN handshake. + (Just io, Just protos) -> do proto <- liftIO $ io protos + sendPacket ctx (Handshake [HsNextProtocolNegotiation proto]) + usingState_ ctx $ setNegotiatedProtocol proto + -- client offered, server didn't pick up. do nothing. + (Just _, Nothing) -> return () + -- client didn't offer. do nothing. + (Nothing, _) -> return () + recvServerHello sentExts = runRecvState ctx (RecvStateHandshake $ onServerHello ctx cparams sentExts) -- | send client Data after receiving all server data (hello/certificates/key). @@ -108,7 +127,7 @@ sendClientData cparams ctx = sendCertificate >> sendClientKeyXchg >> sendCertifi return () Just req -> do - certChain <- liftIO $ onCertificateRequest cparams req `catchException` + certChain <- liftIO $ (onCertificateRequest $ clientHooks cparams) req `catchException` throwMiscErrorOnException "certificate request callback failed" usingHState ctx $ setClientCertSent False @@ -176,7 +195,7 @@ sendClientData cparams ctx = sendCertificate >> sendClientKeyXchg >> sendCertifi malg <- case usedVersion of TLS12 -> do Just (_, Just hashSigs, _) <- usingHState ctx $ getClientCertRequest - let suppHashSigs = pHashSignatures $ ctxParams ctx + let suppHashSigs = supportedHashSignatures $ ctxSupported ctx hashSigs' = filter (\ a -> a `elem` hashSigs) suppHashSigs when (null hashSigs') $ @@ -218,14 +237,14 @@ throwMiscErrorOnException msg e = onServerHello :: Context -> ClientParams -> [ExtensionID] -> Handshake -> IO (RecvState IO) onServerHello ctx cparams sentExts (ServerHello rver serverRan serverSession cipher compression exts) = do when (rver == SSL2) $ throwCore $ Error_Protocol ("ssl2 is not supported", True, ProtocolVersion) - case find ((==) rver) allowedvers of + case find ((==) rver) (supportedVersions $ ctxSupported ctx) of Nothing -> throwCore $ Error_Protocol ("server version " ++ show rver ++ " is not supported", True, ProtocolVersion) Just _ -> return () -- find the compression and cipher methods that the server want to use. - cipherAlg <- case find ((==) cipher . cipherID) ciphers of + cipherAlg <- case find ((==) cipher . cipherID) (ctxCiphers ctx) of Nothing -> throwCore $ Error_Protocol ("server choose unknown cipher", True, HandshakeFailure) Just alg -> return alg - compressAlg <- case find ((==) compression . compressionID) compressions of + compressAlg <- case find ((==) compression . compressionID) (supportedCompressions $ ctxSupported ctx) of Nothing -> throwCore $ Error_Protocol ("server choose unknown compression", True, HandshakeFailure) Just alg -> return alg @@ -251,25 +270,25 @@ onServerHello ctx cparams sentExts (ServerHello rver serverRan serverSession cip _ -> return () case resumingSession of - Nothing -> return $ RecvStateHandshake (processCertificate ctx) + Nothing -> return $ RecvStateHandshake (processCertificate cparams ctx) Just sessionData -> do usingHState ctx (setMasterSecret rver ClientRole $ sessionSecret sessionData) return $ RecvStateNext expectChangeCipher - where params = ctxParams ctx - allowedvers = pAllowedVersions params - ciphers = pCiphers params - compressions = pCompressions params onServerHello _ _ _ p = unexpected (show p) (Just "server hello") -processCertificate :: Context -> Handshake -> IO (RecvState IO) -processCertificate ctx (Certificates certs) = do - usage <- liftIO $ catchException (onCertificatesRecv params certs) rejectOnException +processCertificate :: ClientParams -> Context -> Handshake -> IO (RecvState IO) +processCertificate cparams ctx (Certificates certs) = do + usage <- catchException (wrapCertificateChecks <$> checkCert) rejectOnException case usage of CertificateUsageAccept -> return () CertificateUsageReject reason -> certificateRejected reason return $ RecvStateHandshake (processServerKeyExchange ctx) - where params = ctxParams ctx -processCertificate ctx p = processServerKeyExchange ctx p + where shared = clientShared cparams + checkCert = (onServerCertificate $ clientHooks cparams) (sharedCAStore shared) + (sharedValidationCache shared) + (clientServerIdentification cparams) + certs +processCertificate _ ctx p = processServerKeyExchange ctx p expectChangeCipher :: Packet -> IO (RecvState IO) expectChangeCipher ChangeCipherSpec = return $ RecvStateHandshake expectFinish diff --git a/core/Network/TLS/Handshake/Common.hs b/core/Network/TLS/Handshake/Common.hs index 3e0c675..36a0e0a 100644 --- a/core/Network/TLS/Handshake/Common.hs +++ b/core/Network/TLS/Handshake/Common.hs @@ -16,7 +16,8 @@ module Network.TLS.Handshake.Common import Control.Concurrent.MVar -import Network.TLS.Context +import Network.TLS.Parameters +import Network.TLS.Context.Internal import Network.TLS.Session import Network.TLS.Struct import Network.TLS.IO @@ -45,8 +46,8 @@ unexpected msg expected = throwCore $ Error_Packet_unexpected msg (maybe "" (" e newSession :: Context -> IO Session newSession ctx - | pUseSession $ ctxParams ctx = getStateRNG ctx 32 >>= return . Session . Just - | otherwise = return $ Session Nothing + | supportedSession $ ctxSupported ctx = getStateRNG ctx 32 >>= return . Session . Just + | otherwise = return $ Session Nothing -- | when a new handshake is done, wrap up & clean up. handshakeTerminate :: Context -> IO () @@ -56,7 +57,7 @@ handshakeTerminate ctx = do case session of Session (Just sessionId) -> do sessionData <- getSessionData ctx - withSessionManager (ctxParams ctx) (\s -> liftIO $ sessionEstablish s sessionId (fromJust "session-data" sessionData)) + liftIO $ sessionEstablish (sharedSessionManager $ ctxShared ctx) sessionId (fromJust "session-data" sessionData) _ -> return () -- forget all handshake data now and reset bytes counters. liftIO $ modifyMVar_ (ctxHandshake ctx) (return . const Nothing) @@ -65,24 +66,14 @@ handshakeTerminate ctx = do setEstablished ctx True return () -sendChangeCipherAndFinish :: Context -> Role -> IO () -sendChangeCipherAndFinish ctx role = do +sendChangeCipherAndFinish :: IO () -- ^ message possibly sent between ChangeCipherSpec and Finished. + -> Context + -> Role + -> IO () +sendChangeCipherAndFinish betweenCall ctx role = do sendPacket ctx ChangeCipherSpec - - when (role == ClientRole) $ do - let cparams = getClientParams $ ctxParams ctx - suggest <- usingState_ ctx $ getServerNextProtocolSuggest - case (onNPNServerSuggest cparams, suggest) of - -- client offered, server picked up. send NPN handshake. - (Just io, Just protos) -> do proto <- liftIO $ io protos - sendPacket ctx (Handshake [HsNextProtocolNegotiation proto]) - usingState_ ctx $ setNegotiatedProtocol proto - -- client offered, server didn't pick up. do nothing. - (Just _, Nothing) -> return () - -- client didn't offer. do nothing. - (Nothing, _) -> return () + betweenCall liftIO $ contextFlush ctx - cf <- usingState_ ctx getVersion >>= \ver -> usingHState ctx $ getHandshakeDigest ver role sendPacket ctx (Handshake [Finished cf]) liftIO $ contextFlush ctx diff --git a/core/Network/TLS/Handshake/Key.hs b/core/Network/TLS/Handshake/Key.hs index 3ea65b6..fd6eb9c 100644 --- a/core/Network/TLS/Handshake/Key.hs +++ b/core/Network/TLS/Handshake/Key.hs @@ -22,7 +22,7 @@ import Network.TLS.Handshake.State import Network.TLS.State (withRNG, getVersion) import Network.TLS.Crypto import Network.TLS.Types -import Network.TLS.Context +import Network.TLS.Context.Internal {- if the RSA encryption fails we just return an empty bytestring, and let the protocol - fail by itself; however it would be probably better to just report it since it's an internal problem. diff --git a/core/Network/TLS/Handshake/Process.hs b/core/Network/TLS/Handshake/Process.hs index 030ffb6..ff95a1a 100644 --- a/core/Network/TLS/Handshake/Process.hs +++ b/core/Network/TLS/Handshake/Process.hs @@ -23,7 +23,7 @@ import Network.TLS.Util import Network.TLS.Packet import Network.TLS.Struct import Network.TLS.State -import Network.TLS.Context +import Network.TLS.Context.Internal import Network.TLS.Crypto import Network.TLS.Handshake.State import Network.TLS.Handshake.Key diff --git a/core/Network/TLS/Handshake/Server.hs b/core/Network/TLS/Handshake/Server.hs index f2df07e..b66279b 100644 --- a/core/Network/TLS/Handshake/Server.hs +++ b/core/Network/TLS/Handshake/Server.hs @@ -11,7 +11,8 @@ module Network.TLS.Handshake.Server , handshakeServerWith ) where -import Network.TLS.Context +import Network.TLS.Parameters +import Network.TLS.Context.Internal import Network.TLS.Session import Network.TLS.Struct import Network.TLS.Cipher @@ -80,7 +81,7 @@ handshakeServer sparams ctx = liftIO $ do handshakeServerWith :: ServerParams -> Context -> Handshake -> IO () handshakeServerWith sparams ctx clientHello@(ClientHello clientVersion _ clientSession ciphers compressions exts _) = do -- check if policy allow this new handshake to happens - handshakeAuthorized <- withMeasure ctx (onHandshake $ ctxParams ctx) + handshakeAuthorized <- withMeasure ctx (onHandshake $ ctxCommonHooks ctx) unless handshakeAuthorized (throwCore $ Error_HandshakePolicy "server: handshake denied") updateMeasure ctx incrementNbHandshakes @@ -88,7 +89,7 @@ handshakeServerWith sparams ctx clientHello@(ClientHello clientVersion _ clientS processHandshake ctx clientHello when (clientVersion == SSL2) $ throwCore $ Error_Protocol ("ssl2 is not supported", True, ProtocolVersion) - chosenVersion <- case findHighestVersionFrom clientVersion (pAllowedVersions params) of + chosenVersion <- case findHighestVersionFrom clientVersion (supportedVersions $ ctxSupported ctx) of Nothing -> throwCore $ Error_Protocol ("client version " ++ show clientVersion ++ " is not supported", True, ProtocolVersion) Just v -> return v @@ -98,8 +99,8 @@ handshakeServerWith sparams ctx clientHello@(ClientHello clientVersion _ clientS Error_Protocol ("no compression in common with the client", True, HandshakeFailure) let ciphersFilteredVersion = filter (cipherAllowedForVersion chosenVersion) commonCiphers - usedCipher = (onCipherChoosing sparams) chosenVersion ciphersFilteredVersion - creds = credentialsGet params + usedCipher = (onCipherChoosing $ serverHooks sparams) chosenVersion ciphersFilteredVersion + creds = sharedCredentials $ ctxShared ctx cred <- case cipherKeyExchange usedCipher of CipherKeyExchange_RSA -> return $ credentialsFindForDecrypting creds CipherKeyExchange_DH_Anon -> return $ Nothing @@ -108,20 +109,19 @@ handshakeServerWith sparams ctx clientHello@(ClientHello clientVersion _ clientS _ -> throwCore $ Error_Protocol ("key exchange algorithm not implemented", True, HandshakeFailure) resumeSessionData <- case clientSession of - (Session (Just clientSessionId)) -> withSessionManager params (\s -> liftIO $ sessionResume s clientSessionId) + (Session (Just clientSessionId)) -> liftIO $ sessionResume (sharedSessionManager $ ctxShared ctx) clientSessionId (Session Nothing) -> return Nothing doHandshake sparams cred ctx chosenVersion usedCipher usedCompression clientSession resumeSessionData exts where - params = ctxParams ctx commonCipherIDs = intersect ciphers (map cipherID $ ctxCiphers ctx) commonCiphers = filter (flip elem commonCipherIDs . cipherID) (ctxCiphers ctx) - commonCompressions = compressionIntersectID (pCompressions params) compressions + commonCompressions = compressionIntersectID (supportedCompressions $ ctxSupported ctx) compressions usedCompression = head commonCompressions -handshakeServerWith _ _ _ = fail "unexpected handshake type received. expecting client hello" +handshakeServerWith _ _ _ = throwCore $ Error_Protocol ("unexpected handshake message received in handshakeServerWith", True, HandshakeFailure) doHandshake :: ServerParams -> Maybe Credential -> Context -> Version -> Cipher -> Compression -> Session -> Maybe SessionData @@ -133,13 +133,13 @@ doHandshake sparams mcred ctx chosenVersion usedCipher usedCompression clientSes liftIO $ contextFlush ctx -- Receive client info until client Finished. recvClientData sparams ctx - sendChangeCipherAndFinish ctx ServerRole + sendChangeCipherAndFinish (return ()) ctx ServerRole Just sessionData -> do usingState_ ctx (setSession clientSession True) serverhello <- makeServerHello clientSession sendPacket ctx $ Handshake [serverhello] usingHState ctx $ setMasterSecret chosenVersion ServerRole $ sessionSecret sessionData - sendChangeCipherAndFinish ctx ServerRole + sendChangeCipherAndFinish (return ()) ctx ServerRole recvChangeCipherAndFinish ctx handshakeTerminate ctx where @@ -169,7 +169,7 @@ doHandshake sparams mcred ctx chosenVersion usedCipher usedCompression clientSes else return [] nextProtocols <- if clientRequestedNPN - then liftIO $ onSuggestNextProtocols sparams + then liftIO $ onSuggestNextProtocols $ serverHooks sparams else return Nothing npnExt <- case nextProtocols of Just protos -> do usingState_ ctx $ do setExtensionNPN True @@ -212,7 +212,7 @@ doHandshake sparams mcred ctx chosenVersion usedCipher usedCompression clientSes let certTypes = [ CertificateType_RSA_Sign ] hashSigs = if usedVersion < TLS12 then Nothing - else Just (pHashSignatures $ ctxParams ctx) + else Just (supportedHashSignatures $ ctxSupported ctx) creq = CertRequest certTypes hashSigs (map extractCAname $ serverCACertificates sparams) usingHState ctx $ setCertReqSent True @@ -240,7 +240,7 @@ doHandshake sparams mcred ctx chosenVersion usedCipher usedCompression clientSes usedVersion <- usingState_ ctx getVersion let mhash = case usedVersion of - TLS12 -> case filter ((==) sigAlg . snd) $ pHashSignatures $ ctxParams ctx of + TLS12 -> case filter ((==) sigAlg . snd) $ supportedHashSignatures $ ctxSupported ctx of [] -> error ("no hash signature for " ++ show sigAlg) x:_ -> Just (fst x) _ -> Nothing @@ -269,7 +269,7 @@ recvClientData sparams ctx = runRecvState ctx (RecvStateHandshake processClientC -- Call application callback to see whether the -- certificate chain is acceptable. -- - usage <- liftIO $ catchException (onClientCertificate sparams certs) rejectOnException + usage <- liftIO $ catchException (onClientCertificate (serverHooks sparams) certs) rejectOnException case usage of CertificateUsageAccept -> return () CertificateUsageReject reason -> certificateRejected reason @@ -320,7 +320,7 @@ recvClientData sparams ctx = runRecvState ctx (RecvStateHandshake processClientC -- the signature is wrong. In either case, -- ask the application if it wants to -- proceed, we will do that. - res <- liftIO $ onUnverifiedClientCert sparams + res <- liftIO $ onUnverifiedClientCert (serverHooks sparams) if res then do -- When verification fails, but the diff --git a/core/Network/TLS/Handshake/Signature.hs b/core/Network/TLS/Handshake/Signature.hs index 175bda8..e055c7d 100644 --- a/core/Network/TLS/Handshake/Signature.hs +++ b/core/Network/TLS/Handshake/Signature.hs @@ -18,7 +18,7 @@ module Network.TLS.Handshake.Signature import Crypto.PubKey.HashDescr import Network.TLS.Crypto -import Network.TLS.Context +import Network.TLS.Context.Internal import Network.TLS.Struct import Network.TLS.Packet (generateCertificateVerify_SSL, encodeSignedDHParams) import Network.TLS.State diff --git a/core/Network/TLS/Hooks.hs b/core/Network/TLS/Hooks.hs index 54e092f..4413967 100644 --- a/core/Network/TLS/Hooks.hs +++ b/core/Network/TLS/Hooks.hs @@ -7,13 +7,13 @@ -- module Network.TLS.Hooks ( Logging(..) - , defaultLogging , Hooks(..) , defaultHooks ) where import qualified Data.ByteString as B import Network.TLS.Struct (Header, Handshake(..)) +import Data.Default.Class -- | Hooks for logging data Logging = Logging @@ -31,6 +31,9 @@ defaultLogging = Logging , loggingIORecv = (\_ _ -> return ()) } +instance Default Logging where + def = defaultLogging + -- | A collection of hooks actions. data Hooks = Hooks { hookRecvHandshake :: Handshake -> IO Handshake @@ -41,3 +44,5 @@ defaultHooks = Hooks { hookRecvHandshake = \hs -> return hs } +instance Default Hooks where + def = defaultHooks diff --git a/core/Network/TLS/IO.hs b/core/Network/TLS/IO.hs index 1eb1050..39c0ba4 100644 --- a/core/Network/TLS/IO.hs +++ b/core/Network/TLS/IO.hs @@ -13,10 +13,11 @@ module Network.TLS.IO , recvPacket ) where -import Network.TLS.Context +import Network.TLS.Context.Internal import Network.TLS.Struct import Network.TLS.Record import Network.TLS.Packet +import Network.TLS.Hooks import Network.TLS.Sending import Network.TLS.Receiving import qualified Data.ByteString as B diff --git a/core/Network/TLS/Parameters.hs b/core/Network/TLS/Parameters.hs index 79f806e..c1544ee 100644 --- a/core/Network/TLS/Parameters.hs +++ b/core/Network/TLS/Parameters.hs @@ -5,33 +5,23 @@ -- Stability : experimental -- Portability : unknown -- --- extension RecordWildCards only needed because of some GHC bug --- relative to insufficient polymorphic field -{-# LANGUAGE RecordWildCards #-} module Network.TLS.Parameters ( - -- * Parameters - Params(..) - , RoleParams(..) - , ClientParams(..) + ClientParams(..) , ServerParams(..) - , updateClientParams - , updateServerParams - , Logging(..) - , SessionID - , SessionData(..) + , CommonParams + , ClientHooks(..) + , ServerHooks(..) + , CommonHooks(..) + , Supported(..) + , Shared(..) + -- * special default + , defaultParamsClient + -- * Parameters , MaxFragmentEnum(..) - , Measurement(..) + , Logging(..) , CertificateUsage(..) , CertificateRejectReason(..) - , defaultLogging - , defaultParamsClient - , defaultParamsServer - , withSessionManager - , setSessionManager - , getClientParams - , getServerParams - , credentialsGet ) where import Network.BSD (HostName) @@ -48,15 +38,137 @@ import Network.TLS.Hooks import Network.TLS.Measurement import Network.TLS.X509 import Data.Monoid -import Data.List (intercalate) +import Data.Default.Class import qualified Data.ByteString as B -data ClientParams = ClientParams - { clientUseMaxFragmentLength :: Maybe MaxFragmentEnum - , clientUseServerName :: Maybe HostName - , clientWantSessionResume :: Maybe (SessionID, SessionData) -- ^ try to establish a connection using this session. +type CommonParams = (Supported, Shared, CommonHooks) - -- | This action is called when the server sends a +data ClientParams = ClientParams + { clientUseMaxFragmentLength :: Maybe MaxFragmentEnum + -- | Define the name of the server, along with an extra service identification blob. + -- this is important that the hostname part is properly filled for security reason, + -- as it allow to properly associate the remote side with the given certificate + -- during a handshake. + -- + -- The extra blob is useful to differentiate services running on the same host, but that + -- might have different certificates given. It's only used as part of the X509 validation + -- infrastructure. + , clientServerIdentification :: (HostName, Bytes) + -- | Allow the use of the Server Name Indication TLS extension during handshake, which allow + -- the client to specify which host name, it's trying to access. This is useful to distinguish + -- CNAME aliasing (e.g. web virtual host). + , clientUseServerNameIndication :: Bool + -- | try to establish a connection using this session. + , clientWantSessionResume :: Maybe (SessionID, SessionData) + , clientShared :: Shared + , clientHooks :: ClientHooks + , clientCommonHooks :: CommonHooks + , clientSupported :: Supported + } deriving (Show) + +defaultParamsClient :: HostName -> Bytes -> ClientParams +defaultParamsClient serverName serverId = ClientParams + { clientWantSessionResume = Nothing + , clientUseMaxFragmentLength = Nothing + , clientServerIdentification = (serverName, serverId) + , clientUseServerNameIndication = True + , clientShared = def + , clientHooks = def + , clientCommonHooks = def + , clientSupported = def + } + +data ServerParams = ServerParams + { -- | request a certificate from client. + serverWantClientCert :: Bool + + -- | This is a list of certificates from which the + -- disinguished names are sent in certificate request + -- messages. For TLS1.0, it should not be empty. + , serverCACertificates :: [SignedCertificate] + + -- | Server Optional Diffie Hellman parameters. If this value is not + -- properly set, no Diffie Hellman key exchange will take place. + , serverDHEParams :: Maybe DHParams + + , serverShared :: Shared + , serverHooks :: ServerHooks + , serverCommonHooks :: CommonHooks + , serverSupported :: Supported + } deriving (Show) + +defaultParamsServer :: ServerParams +defaultParamsServer = ServerParams + { serverWantClientCert = False + , serverCACertificates = [] + , serverDHEParams = Nothing + , serverHooks = def + , serverShared = def + , serverCommonHooks = def + , serverSupported = def + } + +instance Default ServerParams where + def = defaultParamsServer + +-- | List all the supported algorithms, versions, ciphers, etc supported. +data Supported = Supported + { + -- | Supported Versions by this context + -- On the client side, the highest version will be used to establish the connection. + -- On the server side, the highest version that is less or equal than the client version will be chosed. + supportedVersions :: [Version] + -- | Supported cipher methods + , supportedCiphers :: [Cipher] + -- | supported compressions methods + , supportedCompressions :: [Compression] + -- | All supported hash/signature algorithms pair for client + -- certificate verification, ordered by decreasing priority. + , supportedHashSignatures :: [HashAndSignatureAlgorithm] + -- | Set if we support secure renegotiation. + , supportedSecureRenegotiation :: Bool + -- | Set if we support session. + , supportedSession :: Bool + } deriving (Show,Eq) + +defaultSupported :: Supported +defaultSupported = Supported + { supportedVersions = [TLS10,TLS11,TLS12] + , supportedCiphers = [] + , supportedCompressions = [nullCompression] + , supportedHashSignatures = [ (Struct.HashSHA512, SignatureRSA) + , (Struct.HashSHA384, SignatureRSA) + , (Struct.HashSHA256, SignatureRSA) + , (Struct.HashSHA224, SignatureRSA) + , (Struct.HashSHA1, SignatureDSS) + ] + , supportedSecureRenegotiation = True + , supportedSession = True + } + +instance Default Supported where + def = defaultSupported + +data Shared = Shared + { sharedCredentials :: Credentials + , sharedSessionManager :: SessionManager + , sharedCAStore :: CertificateStore + , sharedValidationCache :: ValidationCache + } + +instance Show Shared where + show _ = "Shared" +instance Default Shared where + def = Shared + { sharedCAStore = mempty + , sharedCredentials = mempty + , sharedSessionManager = noSessionManager + , sharedValidationCache = def + } + +-- | A set of callbacks run by the clients for various corners of TLS establishment +data ClientHooks = ClientHooks + { -- | This action is called when the server sends a -- certificate request. The parameter is the information -- from the request. The action should select a certificate -- chain of one of the given certificate types where the @@ -75,24 +187,32 @@ data ClientParams = ClientParams -- Returning a certificate chain not matching the -- distinguished names may lead to problems or not, -- depending whether the server accepts it. - , onCertificateRequest :: ([CertificateType], + onCertificateRequest :: ([CertificateType], Maybe [HashAndSignatureAlgorithm], [DistinguishedName]) -> IO (Maybe (CertificateChain, PrivKey)) , onNPNServerSuggest :: Maybe ([B.ByteString] -> IO B.ByteString) + , onServerCertificate :: CertificateStore -> ValidationCache -> ServiceID -> CertificateChain -> IO [FailedReason] } -data ServerParams = ServerParams - { serverWantClientCert :: Bool -- ^ request a certificate from client. +defaultClientHooks :: ClientHooks +defaultClientHooks = ClientHooks + { onCertificateRequest = \ _ -> return Nothing + , onNPNServerSuggest = Nothing + , onServerCertificate = validateDefault + } - -- | This is a list of certificates from which the - -- disinguished names are sent in certificate request - -- messages. For TLS1.0, it should not be empty. - , serverCACertificates :: [SignedCertificate] +instance Show ClientHooks where + show _ = "ClientHooks" +instance Default ClientHooks where + def = defaultClientHooks +-- | A set of callbacks run by the server for various corners of the TLS establishment +data ServerHooks = ServerHooks + { -- | This action is called when a client certificate chain -- is received from the client. When it returns a -- CertificateUsageReject value, the handshake is aborted. - , onClientCertificate :: CertificateChain -> IO CertificateUsage + onClientCertificate :: CertificateChain -> IO CertificateUsage -- | This action is called when the client certificate -- cannot be verified. A 'Nothing' argument indicates a @@ -109,113 +229,35 @@ data ServerParams = ServerParams -- The client cipher list cannot be empty. , onCipherChoosing :: Version -> [Cipher] -> Cipher - -- | Server Optional Diffie Hellman parameters - , serverDHEParams :: Maybe DHParams - -- | suggested next protocols accoring to the next protocol negotiation extension. - , onSuggestNextProtocols :: IO (Maybe [B.ByteString]) + , onSuggestNextProtocols :: IO (Maybe [B.ByteString]) } -data RoleParams = Client ClientParams | Server ServerParams +defaultServerHooks :: ServerHooks +defaultServerHooks = ServerHooks + { onCipherChoosing = \_ -> head + , onClientCertificate = \_ -> return $ CertificateUsageReject $ CertificateRejectOther "no client certificates expected" + , onUnverifiedClientCert = return False + , onSuggestNextProtocols = return Nothing + } -data Params = Params - { pAllowedVersions :: [Version] -- ^ allowed versions that we can use. - -- the default version used for connection is the highest version in the list - , pCiphers :: [Cipher] -- ^ all ciphers supported ordered by priority. - , pCompressions :: [Compression] -- ^ all compression supported ordered by priority. - , pHashSignatures :: [HashAndSignatureAlgorithm] -- ^ All supported hash/signature algorithms pair for client certificate verification, ordered by decreasing priority. - , pUseSecureRenegotiation :: Bool -- ^ notify that we want to use secure renegotation - , pUseSession :: Bool -- ^ generate new session if specified - , pCertificates :: Maybe (CertificateChain, Maybe PrivKey) -- ^ the cert chain for this context with the associated keys if any. - , pCredentials :: Credentials -- ^ credentials - , pLogging :: Logging -- ^ callback for logging +instance Show ServerHooks where + show _ = "ClientHooks" +instance Default ServerHooks where + def = defaultServerHooks + +data CommonHooks = CommonHooks + { onCertificatesRecv :: CertificateChain -> IO CertificateUsage -- ^ callback to verify received cert chain. , onHandshake :: Measurement -> IO Bool -- ^ callback on a beggining of handshake - , onCertificatesRecv :: CertificateChain -> IO CertificateUsage -- ^ callback to verify received cert chain. - , pSessionManager :: SessionManager - , roleParams :: RoleParams - } -{-# DEPRECATED pCertificates "use pCredentials instead of pCertificates. removed in tls-1.3" #-} - -credentialsGet :: Params -> Credentials -credentialsGet params = pCredentials params `mappend` - case pCertificates params of - Just (cchain, Just priv) -> Credentials [(cchain, priv)] - _ -> Credentials [] - --- | Set a new session manager in a parameters structure. -setSessionManager :: SessionManager -> Params -> Params -setSessionManager manager (Params {..}) = Params { pSessionManager = manager, .. } - -withSessionManager :: Params -> (SessionManager -> a) -> a -withSessionManager (Params { pSessionManager = man }) f = f man - -getClientParams :: Params -> ClientParams -getClientParams params = - case roleParams params of - Client clientParams -> clientParams - _ -> error "server params in client context" - -getServerParams :: Params -> ServerParams -getServerParams params = - case roleParams params of - Server serverParams -> serverParams - _ -> error "client params in server context" - -defaultParamsClient :: Params -defaultParamsClient = Params - { pAllowedVersions = [TLS10,TLS11,TLS12] - , pCiphers = [] - , pCompressions = [nullCompression] - , pHashSignatures = [ (Struct.HashSHA512, SignatureRSA) - , (Struct.HashSHA384, SignatureRSA) - , (Struct.HashSHA256, SignatureRSA) - , (Struct.HashSHA224, SignatureRSA) - , (Struct.HashSHA1, SignatureDSS) - ] - , pUseSecureRenegotiation = True - , pUseSession = True - , pCertificates = Nothing - , pCredentials = mempty - , pLogging = defaultLogging - , onHandshake = (\_ -> return True) - , onCertificatesRecv = (\_ -> return CertificateUsageAccept) - , pSessionManager = noSessionManager - , roleParams = Client $ ClientParams - { clientWantSessionResume = Nothing - , clientUseMaxFragmentLength = Nothing - , clientUseServerName = Nothing - , onCertificateRequest = \ _ -> return Nothing - , onNPNServerSuggest = Nothing - } + , logging :: Logging -- ^ callback for logging } -defaultParamsServer :: Params -defaultParamsServer = defaultParamsClient { roleParams = Server role } - where role = ServerParams - { serverWantClientCert = False - , onCipherChoosing = \_ -> head - , serverCACertificates = [] - , serverDHEParams = Nothing - , onClientCertificate = \ _ -> return $ CertificateUsageReject $ CertificateRejectOther "no client certificates expected" - , onUnverifiedClientCert = return False - , onSuggestNextProtocols = return Nothing - } +instance Show CommonHooks where + show _ = "CommonHooks" -updateRoleParams :: (ClientParams -> ClientParams) -> (ServerParams -> ServerParams) -> Params -> Params -updateRoleParams fc fs params = case roleParams params of - Client c -> params { roleParams = Client (fc c) } - Server s -> params { roleParams = Server (fs s) } - -updateClientParams :: (ClientParams -> ClientParams) -> Params -> Params -updateClientParams f = updateRoleParams f id - -updateServerParams :: (ServerParams -> ServerParams) -> Params -> Params -updateServerParams f = updateRoleParams id f - -instance Show Params where - show p = "Params { " ++ (intercalate "," $ map (\(k,v) -> k ++ "=" ++ v) - [ ("allowedVersions", show $ pAllowedVersions p) - , ("ciphers", show $ pCiphers p) - , ("compressions", show $ pCompressions p) - , ("certificates", show $ pCertificates p) - ]) ++ " }" +instance Default CommonHooks where + def = CommonHooks + { onCertificatesRecv = \_ -> return CertificateUsageAccept + , logging = def + , onHandshake = \_ -> return True + } diff --git a/core/Network/TLS/Receiving.hs b/core/Network/TLS/Receiving.hs index cdd2a5e..f835f6d 100644 --- a/core/Network/TLS/Receiving.hs +++ b/core/Network/TLS/Receiving.hs @@ -16,7 +16,7 @@ import Control.Monad.State import Control.Monad.Error import Control.Concurrent.MVar -import Network.TLS.Context +import Network.TLS.Context.Internal import Network.TLS.Struct import Network.TLS.Record import Network.TLS.Packet diff --git a/core/Network/TLS/Sending.hs b/core/Network/TLS/Sending.hs index 1977c7c..178ca4f 100644 --- a/core/Network/TLS/Sending.hs +++ b/core/Network/TLS/Sending.hs @@ -23,7 +23,8 @@ import Network.TLS.Cap import Network.TLS.Struct import Network.TLS.Record import Network.TLS.Packet -import Network.TLS.Context +import Network.TLS.Context.Internal +import Network.TLS.Parameters import Network.TLS.State import Network.TLS.Handshake.State import Network.TLS.Cipher @@ -67,7 +68,7 @@ writePacket ctx pkt = do -- so we use cstIV as is, however in other case we generate an explicit IV prepareRecord :: Context -> RecordM a -> IO (Either TLSError a) prepareRecord ctx f = do - ver <- usingState_ ctx (getVersionWithDefault $ maximum $ pAllowedVersions $ ctxParams ctx) + ver <- usingState_ ctx (getVersionWithDefault $ maximum $ supportedVersions $ ctxSupported ctx) txState <- readMVar $ ctxTxState ctx let sz = case stCipher $ txState of Nothing -> 0 diff --git a/core/Network/TLS/Types.hs b/core/Network/TLS/Types.hs index 31a6ca7..0f3be1f 100644 --- a/core/Network/TLS/Types.hs +++ b/core/Network/TLS/Types.hs @@ -32,7 +32,7 @@ data SessionData = SessionData { sessionVersion :: Version , sessionCipher :: CipherID , sessionSecret :: ByteString - } + } deriving (Show,Eq) -- | Cipher identification type CipherID = Word16 diff --git a/core/Network/TLS/X509.hs b/core/Network/TLS/X509.hs index 8e37878..197637b 100644 --- a/core/Network/TLS/X509.hs +++ b/core/Network/TLS/X509.hs @@ -16,9 +16,18 @@ module Network.TLS.X509 , getCertificateChainLeaf , CertificateRejectReason(..) , CertificateUsage(..) + , CertificateStore + , ValidationCache + , exceptionValidationCache + , validateDefault + , FailedReason + , ServiceID + , wrapCertificateChecks ) where import Data.X509 +import Data.X509.Validation +import Data.X509.CertificateStore isNullCertificateChain :: CertificateChain -> Bool isNullCertificateChain (CertificateChain l) = null l @@ -41,3 +50,10 @@ data CertificateUsage = | CertificateUsageReject CertificateRejectReason -- ^ usage of certificate rejected deriving (Show,Eq) +wrapCertificateChecks :: [FailedReason] -> CertificateUsage +wrapCertificateChecks [] = CertificateUsageAccept +wrapCertificateChecks l + | Expired `elem` l = CertificateUsageReject $ CertificateRejectExpired + | InFuture `elem` l = CertificateUsageReject $ CertificateRejectExpired + | UnknownCA `elem` l = CertificateUsageReject $ CertificateRejectUnknownCA + | otherwise = CertificateUsageReject $ CertificateRejectOther (show l) diff --git a/core/Tests/Connection.hs b/core/Tests/Connection.hs index ef22c18..dcd7403 100644 --- a/core/Tests/Connection.hs +++ b/core/Tests/Connection.hs @@ -14,6 +14,8 @@ import PubKey import PipeChan import Network.TLS import Data.X509 +import Data.X509.Validation +import Data.Default.Class import Control.Applicative import Control.Concurrent.Chan import Control.Concurrent @@ -70,11 +72,11 @@ streamCipher = blockCipher } } -supportedCiphers :: [Cipher] -supportedCiphers = [blockCipher,blockCipherDHE_RSA,blockCipherDHE_DSS,streamCipher] +knownCiphers :: [Cipher] +knownCiphers = [blockCipher,blockCipherDHE_RSA,blockCipherDHE_DSS,streamCipher] -supportedVersions :: [Version] -supportedVersions = [SSL3,TLS10,TLS11,TLS12] +knownVersions :: [Version] +knownVersions = [SSL3,TLS10,TLS11,TLS12] arbitraryPairParams = do (dsaPub, dsaPriv) <- (\(p,r) -> (PubKeyDSA p, PrivKeyDSA r)) <$> arbitraryDSAPair @@ -83,44 +85,54 @@ arbitraryPairParams = do cert <- arbitraryX509WithKey (pub, priv) return (CertificateChain [cert], priv) ) [ (pubKey, privKey), (dsaPub, dsaPriv) ] - connectVersion <- elements supportedVersions - let allowedVersions = [ v | v <- supportedVersions, v <= connectVersion ] + connectVersion <- elements knownVersions + let allowedVersions = [ v | v <- knownVersions, v <= connectVersion ] serAllowedVersions <- (:[]) `fmap` elements allowedVersions serverCiphers <- arbitraryCiphers clientCiphers <- oneof [arbitraryCiphers] `suchThat` (\cs -> or [x `elem` serverCiphers | x <- cs]) secNeg <- arbitrary - --let cred = (CertificateChain [servCert], PrivKeyRSA privKey) - let serverState = defaultParamsServer - { pAllowedVersions = serAllowedVersions - , pCiphers = serverCiphers - , pCredentials = Credentials creds - , pUseSecureRenegotiation = secNeg - , pLogging = logging "server: " - , roleParams = roleParams $ updateServerParams (\sp -> sp { serverDHEParams = Just dhParams }) defaultParamsServer + +-- , pLogging = logging "server: " +-- , pLogging = logging "client: " + + let serverState = def + { serverSupported = def { supportedCiphers = serverCiphers + , supportedVersions = serAllowedVersions + , supportedSecureRenegotiation = secNeg + } + , serverDHEParams = Just dhParams + , serverShared = def { sharedCredentials = Credentials creds } } - let clientState = defaultParamsClient - { pAllowedVersions = allowedVersions - , pCiphers = clientCiphers - , pUseSecureRenegotiation = secNeg - , pLogging = logging "client: " + let clientState = (defaultParamsClient "" B.empty) + { clientSupported = def { supportedCiphers = clientCiphers + , supportedVersions = allowedVersions + , supportedSecureRenegotiation = secNeg + } + , clientShared = def { sharedValidationCache = ValidationCache + { cacheAdd = \_ _ _ -> return () + , cacheQuery = \_ _ _ -> return ValidationCachePass + } + } } return (clientState, serverState) where logging pre = if debug - then defaultLogging { loggingPacketSent = putStrLn . ((pre ++ ">> ") ++) + then def { loggingPacketSent = putStrLn . ((pre ++ ">> ") ++) , loggingPacketRecv = putStrLn . ((pre ++ "<< ") ++) } - else defaultLogging - arbitraryCiphers = resize (length supportedCiphers + 1) $ listOf1 (elements supportedCiphers) + else def + arbitraryCiphers = resize (length knownCiphers + 1) $ listOf1 (elements knownCiphers) -setPairParamsSessionManager :: SessionManager -> (Params, Params) -> (Params, Params) +setPairParamsSessionManager :: SessionManager -> (ClientParams, ServerParams) -> (ClientParams, ServerParams) setPairParamsSessionManager manager (clientState, serverState) = (nc,ns) - where nc = setSessionManager manager clientState - ns = setSessionManager manager serverState + where nc = clientState { clientShared = updateSessionManager $ clientShared clientState } + ns = serverState { serverShared = updateSessionManager $ serverShared serverState } + updateSessionManager shared = shared { sharedSessionManager = manager } -setPairParamsSessionResuming sessionStuff (clientState, serverState) = (nc,serverState) - where nc = updateClientParams (\cparams -> cparams { clientWantSessionResume = Just sessionStuff }) clientState +setPairParamsSessionResuming sessionStuff (clientState, serverState) = + ( clientState { clientWantSessionResume = Just sessionStuff } + , serverState) newPairContext pipe (cParams, sParams) = do let noFlush = return () diff --git a/core/Tests/Tests.hs b/core/Tests/Tests.hs index c7418cb..f5360b6 100644 --- a/core/Tests/Tests.hs +++ b/core/Tests/Tests.hs @@ -72,8 +72,12 @@ prop_handshake_initiate = do prop_handshake_npn_initiate :: PropertyM IO () prop_handshake_npn_initiate = do (clientParam,serverParam) <- pick arbitraryPairParams - let clientParam' = updateClientParams (\cp -> cp { onNPNServerSuggest = Just $ \protos -> return (head protos) }) clientParam - serverParam' = updateServerParams (\sp -> sp { onSuggestNextProtocols = return $ Just [C8.pack "spdy/2", C8.pack "http/1.1"] }) serverParam + let clientParam' = clientParam { clientHooks = (clientHooks clientParam) + { onNPNServerSuggest = Just $ \protos -> return (head protos) } + } + serverParam' = serverParam { serverHooks = (serverHooks serverParam) + { onSuggestNextProtocols = return $ Just [C8.pack "spdy/2", C8.pack "http/1.1"] } + } params' = (clientParam',serverParam') runTLSPipe params' tlsServer tlsClient where tlsServer ctx queue = do diff --git a/core/tls.cabal b/core/tls.cabal index 28474a2..b932ea7 100644 --- a/core/tls.cabal +++ b/core/tls.cabal @@ -37,6 +37,7 @@ Library , cereal >= 0.3 , bytestring , network + , data-default-class , crypto-random >= 0.0 && < 0.1 , crypto-numbers , crypto-pubkey-types >= 0.4 @@ -45,6 +46,7 @@ Library , asn1-encoding , x509 >= 1.4.3 && < 1.5.0 , x509-store + , x509-validation >= 1.5.0 && < 1.6.0 Exposed-modules: Network.TLS Network.TLS.Cipher Network.TLS.Compression @@ -53,6 +55,7 @@ Library Network.TLS.Struct Network.TLS.Core Network.TLS.Context + Network.TLS.Context.Internal Network.TLS.Credentials Network.TLS.Backend Network.TLS.Crypto @@ -102,6 +105,7 @@ Test-Suite test-tls Build-Depends: base >= 3 && < 5 , mtl , cereal >= 0.3 + , data-default-class , QuickCheck >= 2 , test-framework , test-framework-quickcheck2 @@ -109,6 +113,7 @@ Test-Suite test-tls , crypto-pubkey >= 0.2 , bytestring , x509 + , x509-validation , tls , time , crypto-random @@ -122,6 +127,8 @@ Benchmark bench-tls Build-depends: base >= 4 && < 5 , tls , x509 + , x509-validation + , data-default-class , crypto-random , criterion , cprng-aes